{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a3b44703",
   "metadata": {},
   "source": [
    "# Simple example"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "416b0a4b",
   "metadata": {},
   "source": [
    "Copyright 2016 The BigDL Authors."
   ]
  },
  {
   "cell_type": "raw",
   "id": "a474a628",
   "metadata": {},
   "source": [
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf304c91",
   "metadata": {},
   "source": [
    "`SparkXShards` in Orca lets users process large-scale datasets with existing Python code in a distributed, data-parallel fashion, as shown below. This notebook is a simple deep learning example that uses Keras and `SparkXShards`.\n",
    "\n",
    "It is adapted from [Your First Deep Learning Project in Python with Keras Step-by-Step](https://machinelearningmastery.com/tutorial-first-neural-network-python-keras) and uses the Pima Indians diabetes dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92f533d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import necessary libraries\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense\n",
    "\n",
    "import bigdl.orca.data.pandas\n",
    "from bigdl.orca import init_orca_context, stop_orca_context\n",
    "from bigdl.orca.learn.tf2.estimator import Estimator\n",
    "import warnings\n",
    "import math\n",
    "\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b6e8f16",
   "metadata": {},
   "source": [
    "Start an OrcaContext. By default, the read backend uses Spark to read each file into a pandas DataFrame."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdcb1053",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc = init_orca_context(memory=\"4g\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a46ad6a7",
   "metadata": {},
   "source": [
    "## Load data in parallel and get general information"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "904060e4",
   "metadata": {},
   "source": [
    "Load the data into `data_shards`, a `SparkXShards` that can be operated on in parallel. Each element of `data_shards` is a pandas DataFrame read from a file on the cluster. Users can distribute a local `pd.read_csv(dataFile)` call by replacing it with `bigdl.orca.data.pandas.read_csv(datapath)`."
   ]
  },
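  {
   "cell_type": "markdown",
   "id": "5e1a0c01",
   "metadata": {},
   "source": [
    "For comparison, the single-machine call that `bigdl.orca.data.pandas.read_csv` stands in for can be sketched with plain pandas. The in-memory `sample_csv` is an assumption made for illustration: it holds three rows copied from this notebook's `data_shards.head(5)` output rather than the full file.\n",
    "\n",
    "```python\n",
    "import io\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical in-memory stand-in for the CSV file (three rows only).\n",
    "sample_csv = io.StringIO(\n",
    "    '6,148,72,35,0,33.6,0.627,50,1\\n'\n",
    "    '1,85,66,29,0,26.6,0.351,31,0\\n'\n",
    "    '8,183,64,0,0,23.3,0.672,32,1\\n'\n",
    ")\n",
    "\n",
    "# header=None tells pandas the file has no header row,\n",
    "# so the columns are named by position.\n",
    "local_df = pd.read_csv(sample_csv, header=None)\n",
    "print(local_df.shape)  # (3, 9)\n",
    "```\n",
    "\n",
    "With `header=None`, the columns are identified by position, which is why later cells select features and labels by slicing the column list."
   ]
  },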
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9daad62",
   "metadata": {},
   "outputs": [],
   "source": [
    "datapath = '../pima-indians-diabetes.csv'\n",
    "data_shards = bigdl.orca.data.pandas.read_csv(datapath, header=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6d12d912",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>6</td>\n",
       "      <td>148</td>\n",
       "      <td>72</td>\n",
       "      <td>35</td>\n",
       "      <td>0</td>\n",
       "      <td>33.6</td>\n",
       "      <td>0.627</td>\n",
       "      <td>50</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>85</td>\n",
       "      <td>66</td>\n",
       "      <td>29</td>\n",
       "      <td>0</td>\n",
       "      <td>26.6</td>\n",
       "      <td>0.351</td>\n",
       "      <td>31</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>8</td>\n",
       "      <td>183</td>\n",
       "      <td>64</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>23.3</td>\n",
       "      <td>0.672</td>\n",
       "      <td>32</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>89</td>\n",
       "      <td>66</td>\n",
       "      <td>23</td>\n",
       "      <td>94</td>\n",
       "      <td>28.1</td>\n",
       "      <td>0.167</td>\n",
       "      <td>21</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>137</td>\n",
       "      <td>40</td>\n",
       "      <td>35</td>\n",
       "      <td>168</td>\n",
       "      <td>43.1</td>\n",
       "      <td>2.288</td>\n",
       "      <td>33</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   0    1   2   3    4     5      6   7  8\n",
       "0  6  148  72  35    0  33.6  0.627  50  1\n",
       "1  1   85  66  29    0  26.6  0.351  31  0\n",
       "2  8  183  64   0    0  23.3  0.672  32  1\n",
       "3  1   89  66  23   94  28.1  0.167  21  0\n",
       "4  0  137  40  35  168  43.1  2.288  33  1"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# show the first five rows of data_shards\n",
    "data_shards.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "62d45e04",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# show the number of partitions in data_shards\n",
    "data_shards.num_partitions()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3b8526d9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "768"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# count total number of rows in the data_shards\n",
    "len(data_shards)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07efecce",
   "metadata": {},
   "source": [
    "## Assemble features and labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d8d0c7f5",
   "metadata": {},
   "outputs": [],
   "source": [
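    "columns = list(data_shards.get_schema()['columns'])\n",
    "# The last column is the label. Wrap it in a list literal: calling list() on a\n",
    "# single integer column name raises a TypeError.\n",
    "data_shards = data_shards.assembleFeatureLabelCols(featureCols=columns[:-1],\n",
    "                                                   labelCols=[columns[-1]])"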
   ]
  },
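  {
   "cell_type": "markdown",
   "id": "5e1a0c02",
   "metadata": {},
   "source": [
    "As a rough mental model, and assuming `assembleFeatureLabelCols` turns the selected columns of each pandas DataFrame into NumPy arrays, the per-shard split can be sketched locally as follows. The tiny `df` and the names `x` and `y` are hypothetical and only illustrate the column selection; this is not the BigDL implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical stand-in for one shard: integer column names, label in the last column.\n",
    "df = pd.DataFrame([[6, 148, 1], [1, 85, 0], [8, 183, 1]])\n",
    "columns = list(df.columns)\n",
    "\n",
    "feature_cols = columns[:-1]  # every column except the last\n",
    "label_cols = [columns[-1]]   # a one-element list holding the label column\n",
    "\n",
    "# Select the columns and convert them to float arrays.\n",
    "x = df[feature_cols].to_numpy(dtype=np.float32)\n",
    "y = df[label_cols].to_numpy(dtype=np.float32)\n",
    "print(x.shape, y.shape)  # (3, 2) (3, 1)\n",
    "```\n",
    "\n",
    "Selecting with a list of column names keeps the label two-dimensional, with one row per sample and one column per label."
   ]
  },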
  {
   "cell_type": "markdown",
   "id": "dd62a210",
   "metadata": {},
   "source": [
    "## Define Keras model and train it"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "175286e4",
   "metadata": {
    "scrolled": true
   },
   "source": [
    "Build the model as usual. Here, we'll use a Sequential model with two densely connected hidden layers and a sigmoid output layer that returns a single probability for binary classification.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a17921c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_creator(config):\n",
    "    model = Sequential()\n",
    "    model.add(Dense(12, input_shape=(8,), activation='relu'))\n",
    "    model.add(Dense(8, activation='relu'))\n",
    "    model.add(Dense(1, activation='sigmoid'))\n",
    "\n",
    "    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "047c0679",
   "metadata": {},
   "outputs": [],
   "source": [
    "est = Estimator.from_keras(model_creator=model_creator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d9832fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 16\n",
    "train_steps = math.ceil(len(data_shards) / batch_size)\n",
    "est.fit(data=data_shards,\n",
    "        batch_size=batch_size,\n",
    "        epochs=150,\n",
    "        steps_per_epoch=train_steps)"
   ]
  },
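  {
   "cell_type": "markdown",
   "id": "5e1a0c03",
   "metadata": {},
   "source": [
    "The `steps_per_epoch` arithmetic used for training can be checked on its own. This sketch hard-codes 768, the row count that `len(data_shards)` reported earlier in this notebook; the 770-row case is a made-up value that shows why `math.ceil` is used.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "num_rows = 768   # row count reported by len(data_shards)\n",
    "batch_size = 16\n",
    "\n",
    "# One step consumes one batch, so an epoch needs ceil(rows / batch_size) steps.\n",
    "steps_per_epoch = math.ceil(num_rows / batch_size)\n",
    "print(steps_per_epoch)  # 48\n",
    "\n",
    "# Hypothetical row count that is not a multiple of the batch size:\n",
    "# rounding up keeps the final, smaller batch in the epoch.\n",
    "steps_with_remainder = math.ceil(770 / batch_size)\n",
    "print(steps_with_remainder)  # 49\n",
    "```"
   ]
  },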
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "68ed8c57",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stopping orca context\n"
     ]
    }
   ],
   "source": [
    "stop_orca_context()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37tf2_x",
   "language": "python",
   "name": "py37tf2_x"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
